package db

import (
	"context"
	"errors"
	"fmt"
	_ "github.com/gogf/gf/contrib/drivers/mysql/v2"
	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/gcache"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/syyongx/php2go"
	"regexp"
)

// GetUnSafaTable returns a model for tableName bound to ctx, with Unscoped()
// applied so GoFrame's automatic soft-delete / time-field handling is turned off.
// When dbName is non-empty, its first element names the database configuration
// group to use; otherwise the default database (g.Model) is used.
func GetUnSafaTable(ctx context.Context, tableName string, dbName ...string) *gdb.Model {
	if len(dbName) == 0 {
		return g.Model(tableName).Ctx(ctx).Unscoped()
	}
	group := dbName[0]
	return g.DB(group).Model(tableName).Ctx(ctx).Unscoped()
}

// SoftDeleteHandler restricts m to rows whose delete_at column equals 0 when
// the model's table has a delete_at column; otherwise m is returned unchanged.
//
// The model returned by Where is what gets returned. Relying on m being
// mutated in place only works for non-chain-safe models. For a chain-safe
// model, Where returns a modified copy and the filter would be lost.
func SoftDeleteHandler(m *gdb.Model) *gdb.Model {
	// Best-effort: the HasField error is intentionally ignored, and only the
	// boolean it reports decides whether the filter is applied.
	hasField, _ := m.HasField("delete_at")
	if !hasField {
		return m
	}
	return m.Where("delete_at", 0)
}

// tableSchema is one row of information_schema.COLUMNS as selected by
// getColumnDefaultBySql: the column's name, its default value as text,
// and its DATA_TYPE.
type tableSchema struct {
	ColumnName, ColumnDefault, DataType string
}

// getColumnDefaultBySql loads the column metadata (name, default, data type)
// of tableName from information_schema.COLUMNS. It reads the database
// currently selected on db's connection (DATABASE()), ordered by column
// position. db's configured table prefix is prepended to tableName for the
// lookup.
//
// Results are memoized in the process-wide gcache with no expiry (duration 0),
// so schema changes made while the process runs are not picked up. The cache
// key includes db's group and schema, so identically named tables in different
// databases do not share an entry. Previously the group part was a
// hard-coded placeholder.
func getColumnDefaultBySql(db gdb.DB, tableName string) ([]*tableSchema, error) {
	cacheKey := fmt.Sprintf(
		`mysql_table_fields_%s_%s@group:%s@schema:%s`,
		tableName, "GetColumnDefaultBySql", db.GetGroup(), db.GetSchema(),
	)
	ctx := context.Background()
	data, err := gcache.GetOrSetFuncLock(ctx, cacheKey, func(ctx context.Context) (interface{}, error) {
		sql := "SELECT COLUMN_NAME,COLUMN_DEFAULT,DATA_TYPE FROM information_schema.COLUMNS WHERE table_schema = DATABASE()" +
			" AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION"

		var columns []*tableSchema
		if err := db.GetScan(ctx, &columns, sql, db.GetPrefix()+tableName); err != nil {
			g.Log().Cat("tooleDb").Async().Error(ctx, "GetColumnDefault 失败 ", err.Error())
			return nil, err
		}
		return columns, nil
	}, 0)
	if err != nil {
		return nil, err
	}
	if data == nil {
		return nil, errors.New("未知失败")
	}

	var columns []*tableSchema
	if err := data.Scan(&columns); err != nil {
		return nil, err
	}
	return columns, nil
}

// integerStringPattern matches an optional leading '-' followed by at least
// one decimal digit (e.g. "42", "-7", "007"). Unlike the previous pattern,
// it rejects a lone "-". It is compiled once instead of on every call.
var integerStringPattern = regexp.MustCompile(`^-?[0-9]+$`)

// GetColumnDefault checks columnValue against the type of column columnName in
// tableName and returns the value that should be written to that column.
//
// Accepted Go types for columnValue are the built-in integer and float types,
// string and json.Number; any other type is rejected with an error.
//
// Based on the column type, as mapped by typeForMysqlToGo:
//   - integer columns accept integer values, and strings consisting of an
//     optional '-' followed by digits; float values and other strings mismatch;
//   - float columns accept values that php2go.IsNumeric reports as numeric;
//   - every other column (string, date/time, bool, unmapped types) receives
//     columnValue unchanged.
//
// On a mismatch, isDefault == true returns the column's default (the
// COLUMN_DEFAULT text) instead, while isDefault == false returns an error.
// An error is also returned when tableName has no column named columnName.
func GetColumnDefault(db gdb.DB, isDefault bool, tableName, columnName string, columnValue interface{}) (interface{}, error) {
	columns, err := getColumnDefaultBySql(db, tableName)
	if err != nil {
		return nil, err
	}

	// Collapse the concrete Go type into one of three input categories.
	columnType := fmt.Sprintf("%T", columnValue)
	switch columnType {
	case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64":
		columnType = "int"
	case "float32", "float64":
		columnType = "float"
	case "string", "json.Number":
		columnType = "string"
	default:
		return nil, errors.New(columnName + " 不支持的类型 " + columnType)
	}

	for _, item := range columns {
		if item.ColumnName != columnName {
			continue
		}

		// mismatch yields the column default, or an error when defaults are
		// not allowed; desc names the column/input type combination.
		mismatch := func(desc string) (interface{}, error) {
			if isDefault {
				return item.ColumnDefault, nil
			}
			return nil, errors.New(columnName + " 输入类型不匹配，当前类型：" + desc)
		}

		switch typeForMysqlToGo(item.DataType) {
		case "int64":
			switch columnType {
			case "string":
				if integerStringPattern.MatchString(gconv.String(columnValue)) {
					return columnValue, nil
				}
				return mismatch("int64-string")
			case "float":
				return mismatch("int64-float")
			default:
				return columnValue, nil
			}
		case "float64":
			// Both int and float inputs (and numeric strings) are acceptable.
			if php2go.IsNumeric(columnValue) {
				return columnValue, nil
			}
			return mismatch("float64")
		default:
			// String columns, plus types without a validation rule (date/time,
			// bool, unmapped), are passed through as-is. Previously the
			// unhandled ones fell through to a misleading "column not found" error.
			return columnValue, nil
		}
	}

	return nil, errors.New("未找到匹配的列 " + columnName)
}

// mysqlDataTypeToGo maps lower-case MySQL data type names to the Go type
// names that GetColumnDefault switches on. It is built once at package
// initialization instead of on every typeForMysqlToGo call. It must be
// treated as read-only.
var mysqlDataTypeToGo = map[string]string{
	"int":                "int64",
	"integer":            "int64",
	"tinyint":            "int64",
	"smallint":           "int64",
	"mediumint":          "int64",
	"bigint":             "int64",
	"int unsigned":       "int64",
	"integer unsigned":   "int64",
	"tinyint unsigned":   "int64",
	"smallint unsigned":  "int64",
	"mediumint unsigned": "int64",
	"bigint unsigned":    "int64",
	"bit":                "int64",
	"bool":               "bool",
	"enum":               "string",
	"set":                "string",
	"varchar":            "string",
	"char":               "string",
	"tinytext":           "string",
	"mediumtext":         "string",
	"text":               "string",
	"longtext":           "string",
	"blob":               "string",
	"tinyblob":           "string",
	"mediumblob":         "string",
	"longblob":           "string",
	"date":               "time.Time", // time.Time or string
	"datetime":           "time.Time", // time.Time or string
	"timestamp":          "time.Time", // time.Time or string
	"time":               "time.Time", // time.Time or string
	"float":              "float64",
	"double":             "float64",
	"decimal":            "float64",
	"binary":             "string",
	"varbinary":          "string",
	"json":               "string",
}

// typeForMysqlToGo returns the Go type name ("int64", "float64", "string",
// "bool" or "time.Time") for the MySQL data type typeStr. It returns "" when
// typeStr is not in the table. The lookup is exact and case-sensitive, and
// all keys are lower-case.
func typeForMysqlToGo(typeStr string) string {
	return mysqlDataTypeToGo[typeStr]
}

// FilterColumnParams runs every entry of params through GetColumnDefault for
// tableName and returns the resulting values in a new map.
//
// isDefault == true replaces values whose type does not match their column
// with the column default. isDefault == false makes the first mismatch abort
// with an error.
//
// params itself is never modified, so an error part-way through does not leave
// the caller's map half-rewritten. A nil params yields a nil map.
func FilterColumnParams(db gdb.DB, isDefault bool, params map[string]interface{}, tableName string) (map[string]interface{}, error) {
	if params == nil {
		return nil, nil
	}

	filtered := make(map[string]interface{}, len(params))
	for key, value := range params {
		checked, err := GetColumnDefault(db, isDefault, tableName, key, value)
		if err != nil {
			return nil, err
		}
		filtered[key] = checked
	}
	return filtered, nil
}
